{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Analysis of ODE-LSTM Results\n",
    "\n",
    "This notebook creates the comparison between ODE-LSTM and MTS-LSTM from the paper.\n",
    "To reproduce the contents of this notebook, you need to download the models' predictions (or create them yourself) into the folder `BASE_DIR`, which is set in the first code cell below.\n",
    "\n",
    "`README.md` contains information on where to obtain the required data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import xarray as xr\n",
    "from scipy.stats import wilcoxon\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "from neuralhydrology.evaluation.metrics import calculate_metrics\n",
    "\n",
    "# Folder containing the pickled ensemble predictions (`ensemble_<model>_<experiment>_<basin>.p`).\n",
    "# NOTE(review): absolute, machine-specific path -- point it to your local copy of the results.\n",
    "BASE_DIR = Path('/home/mgauch/mts-lstm/results/odelstm')\n",
    "\n",
    "# Basin identifiers; they appear in the prediction file names and serve as dictionary keys.\n",
    "basins = [\n",
    "    '01022500', '02064000', '02374500', '05593575', '06404000',\n",
    "    '06889500', '08190000', '09352900', '11481200', '12189500',\n",
    "]\n",
    "\n",
    "# Name of the metric requested from `calculate_metrics` and used to build result keys.\n",
    "metric = 'NSE'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preparation\n",
    "### Load predictions and metrics for each model ensemble"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (a) train on 1D+12H, evaluate on 1H (dividing 12H-predictions by 12)\n",
    "# (b) train on 1H+3H, evaluate on 1D (aggregating every 8 3H-predictions)\n",
    "# (c) train on 1H+1D, evaluate on 1H+1D\n",
    "def load_ensemble(model, experiment, basin):\n",
    "    \"\"\"Load one basin's entry from a pickled ensemble-results file in `BASE_DIR`.\n",
    "\n",
    "    Reads `BASE_DIR / 'ensemble_{model}_{experiment}_{basin}.p'` and returns the value stored\n",
    "    under the key `basin`. The file is opened in a `with` block so its handle is closed\n",
    "    right after loading instead of being left open until garbage collection.\n",
    "\n",
    "    Args:\n",
    "        model: Model tag used in the file name ('mtslstm' or 'odelstm').\n",
    "        experiment: Experiment tag used in the file name ('a' or 'b').\n",
    "        basin: Basin identifier; part of the file name and the key looked up in the pickle.\n",
    "\n",
    "    Raises:\n",
    "        FileNotFoundError: If the results file does not exist in `BASE_DIR`.\n",
    "        KeyError: If the unpickled object has no entry for `basin`.\n",
    "    \"\"\"\n",
    "    # NOTE: unpickling can execute arbitrary code -- only load result files from a trusted source.\n",
    "    with open(BASE_DIR / f'ensemble_{model}_{experiment}_{basin}.p', 'rb') as fp:\n",
    "        return pickle.load(fp)[basin]\n",
    "\n",
    "\n",
    "a_mtslstm, b_mtslstm = {}, {}\n",
    "a_odelstm, b_odelstm = {}, {}\n",
    "for basin in basins:\n",
    "    # MTS-LSTM predictions (single-basin)\n",
    "    a_mtslstm[basin] = load_ensemble('mtslstm', 'a', basin)\n",
    "    b_mtslstm[basin] = load_ensemble('mtslstm', 'b', basin)\n",
    "\n",
    "    # ODE-LSTM (single-basin)\n",
    "    a_odelstm[basin] = load_ensemble('odelstm', 'a', basin)\n",
    "    b_odelstm[basin] = load_ensemble('odelstm', 'b', basin)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### (Dis-)aggregate MTS-LSTM predictions to missing timescales and calculate metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ef352599fed147aebf59b1797948d24c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/mgauch/miniconda3/envs/pytorch/lib/python3.7/site-packages/xarray/core/nanops.py:142: RuntimeWarning: Mean of empty slice\n",
      "  return np.nanmean(a, axis=axis, dtype=dtype)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Names of the observation (`_obs`) and simulation (`_sim`) variables in the prediction datasets.\n",
    "obs_var, sim_var = 'qobs_mm_per_hour_obs', 'qobs_mm_per_hour_sim'\n",
    "\n",
    "for basin in tqdm(basins):\n",
    "    exp_a, exp_b = a_mtslstm[basin], b_mtslstm[basin]\n",
    "\n",
    "    # Derive the timescales the MTS-LSTM runs did not predict directly:\n",
    "    # experiment (a): 12H -> 1H by forward-filling, experiment (b): 3H -> 1D by averaging.\n",
    "    exp_a['1H']['xr'] = exp_a['12H']['xr'].resample({'datetime': '1H'}).ffill()\n",
    "    exp_b['1D']['xr'] = exp_b['3H']['xr'].resample({'datetime': '1D'}).mean()\n",
    "\n",
    "    # Each resampled dataset takes its observations from the other experiment's\n",
    "    # results at the same timescale.\n",
    "    exp_a['1H']['xr'][obs_var] = exp_b['1H']['xr'][obs_var]\n",
    "    exp_b['1D']['xr'][obs_var] = exp_a['1D']['xr'][obs_var]\n",
    "\n",
    "    # Score the derived predictions and store the metric next to them.\n",
    "    for results, freq in ((exp_a, '1H'), (exp_b, '1D')):\n",
    "        ds = results[freq]['xr']\n",
    "        results[freq][f'{metric}_{freq}'] = calculate_metrics(ds[obs_var], ds[sim_var],\n",
    "                                                              [metric], resolution=freq)[metric]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def to_df(dict_mtslstm, dict_odelstm):\n",
    "    \"\"\"Collect the per-basin metric values of both models into one DataFrame.\n",
    "\n",
    "    Both arguments map basin -> frequency -> results, where the results hold the metric\n",
    "    value under the key `'{metric}_{frequency}'`. The returned frame has one row per\n",
    "    (model name, basin) pair -- all 'MTS-LSTM' rows first, then all 'ODE-LSTM' rows -- and\n",
    "    one column per `'{metric}_{frequency}'` key. Relies on the notebook-level `basins`\n",
    "    and `metric`.\n",
    "    \"\"\"\n",
    "    rows = {}\n",
    "    for model_name, results in (('MTS-LSTM', dict_mtslstm), ('ODE-LSTM', dict_odelstm)):\n",
    "        for basin in basins:\n",
    "            per_freq = results[basin]\n",
    "            rows[(model_name, basin)] = {f'{metric}_{freq}': per_freq[freq][f'{metric}_{freq}']\n",
    "                                         for freq in per_freq}\n",
    "    return pd.DataFrame.from_dict(rows, orient='index')\n",
    "\n",
    "\n",
    "# Metric tables for experiment (a) and experiment (b).\n",
    "a_df = to_df(a_mtslstm, a_odelstm)\n",
    "b_df = to_df(b_mtslstm, b_odelstm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(A) Train on 1D, 12H. Evaluate on 1H. Medians:\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>NSE_1D</th>\n",
       "      <th>NSE_12H</th>\n",
       "      <th>NSE_1H</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>MTS-LSTM</th>\n",
       "      <td>0.726355</td>\n",
       "      <td>0.734082</td>\n",
       "      <td>0.705742</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ODE-LSTM</th>\n",
       "      <td>0.719632</td>\n",
       "      <td>0.705877</td>\n",
       "      <td>0.639060</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "            NSE_1D   NSE_12H    NSE_1H\n",
       "MTS-LSTM  0.726355  0.734082  0.705742\n",
       "ODE-LSTM  0.719632  0.705877  0.639060"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "   Means\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>NSE_1D</th>\n",
       "      <th>NSE_12H</th>\n",
       "      <th>NSE_1H</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>MTS-LSTM</th>\n",
       "      <td>0.664499</td>\n",
       "      <td>0.672290</td>\n",
       "      <td>0.633878</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ODE-LSTM</th>\n",
       "      <td>0.650770</td>\n",
       "      <td>0.638408</td>\n",
       "      <td>0.591831</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "            NSE_1D   NSE_12H    NSE_1H\n",
       "MTS-LSTM  0.664499  0.672290  0.633878\n",
       "ODE-LSTM  0.650770  0.638408  0.591831"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(B) Train on 1H, 3H. Evaluate on 1D. Medians:\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>NSE_1H</th>\n",
       "      <th>NSE_3H</th>\n",
       "      <th>NSE_1D</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>MTS-LSTM</th>\n",
       "      <td>0.700421</td>\n",
       "      <td>0.727794</td>\n",
       "      <td>0.745852</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ODE-LSTM</th>\n",
       "      <td>0.677459</td>\n",
       "      <td>0.674996</td>\n",
       "      <td>0.586775</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "            NSE_1H    NSE_3H    NSE_1D\n",
       "MTS-LSTM  0.700421  0.727794  0.745852\n",
       "ODE-LSTM  0.677459  0.674996  0.586775"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "   Means\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>NSE_1H</th>\n",
       "      <th>NSE_3H</th>\n",
       "      <th>NSE_1D</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>MTS-LSTM</th>\n",
       "      <td>0.633374</td>\n",
       "      <td>0.672315</td>\n",
       "      <td>0.718217</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ODE-LSTM</th>\n",
       "      <td>0.585862</td>\n",
       "      <td>0.592676</td>\n",
       "      <td>0.546010</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "            NSE_1H    NSE_3H    NSE_1D\n",
       "MTS-LSTM  0.633374  0.672315  0.718217\n",
       "ODE-LSTM  0.585862  0.592676  0.546010"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Summarize per model: index level 0 of both frames is the model name, so grouping on it\n",
    "# aggregates over the basins. `groupby(level=0)` replaces the `level` argument of\n",
    "# `DataFrame.median`/`DataFrame.mean`, which was deprecated in pandas 1.3 and removed in\n",
    "# pandas 2.0; `sort=False` keeps the rows in their order of first appearance.\n",
    "print('(A) Train on 1D, 12H. Evaluate on 1H. Medians:')\n",
    "display(a_df.groupby(level=0, sort=False).median())\n",
    "print('   Means')\n",
    "display(a_df.groupby(level=0, sort=False).mean())\n",
    "\n",
    "print('(B) Train on 1H, 3H. Evaluate on 1D. Medians:')\n",
    "display(b_df.groupby(level=0, sort=False).median())\n",
    "print('   Means')\n",
    "display(b_df.groupby(level=0, sort=False).mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
